# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Handles control flow statements: while, for, if."""

import gast

from nvidia.dali._autograph.core import converter
from nvidia.dali._autograph.lang import directives
from nvidia.dali._autograph.pyct import anno
from nvidia.dali._autograph.pyct import cfg
from nvidia.dali._autograph.pyct import origin_info
from nvidia.dali._autograph.pyct import parser
from nvidia.dali._autograph.pyct import qual_names
from nvidia.dali._autograph.pyct import templates
from nvidia.dali._autograph.pyct.static_analysis import activity
from nvidia.dali._autograph.pyct.static_analysis import annos
from nvidia.dali._autograph.pyct.static_analysis import liveness
from nvidia.dali._autograph.pyct.static_analysis import reaching_definitions
from nvidia.dali._autograph.pyct.static_analysis import reaching_fndefs


class _Function(object):
    scope = None


class ControlFlowTransformer(converter.Base):
    """Transforms control flow structures like loops an conditionals."""

    def visit_Lambda(self, node):
        with self.state[_Function] as fn:
            fn.scope = anno.getanno(node, anno.Static.SCOPE)
            return self.generic_visit(node)

    def visit_FunctionDef(self, node):
        with self.state[_Function] as fn:
            fn.scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
            return self.generic_visit(node)

    def _create_nonlocal_declarations(self, vars_):
        vars_ = set(vars_)
        results = []
        global_vars = self.state[_Function].scope.globals & vars_

        if global_vars:
            results.append(gast.Global([str(v) for v in global_vars]))

        nonlocal_vars = [v for v in vars_ if not v.is_composite() and v not in global_vars]
        if nonlocal_vars:
            results.append(gast.Nonlocal([str(v) for v in nonlocal_vars]))

        return results

    def _create_state_functions(self, block_vars, nonlocal_declarations, getter_name, setter_name):
        if not block_vars:
            template = """
        def getter_name():
          return ()
        def setter_name(block_vars):
          pass
      """
            return templates.replace(template, getter_name=getter_name, setter_name=setter_name)

        guarded_block_vars = []
        for v in block_vars:
            if v.is_simple():
                guarded_block_vars.append(v)
            else:
                guarded_block_vars.append(
                    templates.replace_as_expression(
                        "ag__.ldu(lambda: var_, name)",
                        var_=v,
                        name=gast.Constant(str(v), kind=None),
                    )
                )

        template = """
      def getter_name():
        return guarded_state_vars,
      def setter_name(vars_):
        nonlocal_declarations
        state_vars, = vars_
    """
        return templates.replace(
            template,
            nonlocal_declarations=nonlocal_declarations,
            getter_name=getter_name,
            guarded_state_vars=guarded_block_vars,
            setter_name=setter_name,
            state_vars=tuple(block_vars),
        )

    def _create_loop_options(self, node):
        if not anno.hasanno(node, anno.Basic.DIRECTIVES):
            return gast.Dict([], [])

        loop_directives = anno.getanno(node, anno.Basic.DIRECTIVES)
        if directives.set_loop_options not in loop_directives:
            return gast.Dict([], [])

        opts_dict = loop_directives[directives.set_loop_options]
        str_keys, values = zip(*opts_dict.items())
        keys = [gast.Constant(s, kind=None) for s in str_keys]
        values = list(values)  # ast and gast don't play well with tuples.
        return gast.Dict(keys, values)

    def _create_undefined_assigns(self, undefined_symbols):
        assignments = []
        for s in undefined_symbols:
            template = """
        var = ag__.Undefined(symbol_name)
      """
            assignments += templates.replace(
                template, var=s, symbol_name=gast.Constant(s.ssf(), kind=None)
            )
        return assignments

    def _get_block_basic_vars(self, modified, live_in, live_out):
        nonlocals = self.state[_Function].scope.nonlocals
        basic_scope_vars = []
        for s in modified:
            if s.is_composite():
                # TODO(mdan): Raise an error when this happens for a TF scope.
                continue
            # Variables not live into or out of the scope are considered local to the
            # scope.
            if s in live_in or s in live_out or s in nonlocals:
                basic_scope_vars.append(s)
            continue
        return frozenset(basic_scope_vars)

    def _get_block_composite_vars(self, modified, live_in):
        # The scope variables corresponding to composite symbols (e.g. `self.x`).
        composite_scope_vars = []
        for s in modified:
            if not s.is_composite():
                continue
            # Mutations made to objects created inside the scope will appear as writes
            # to composite symbols. Because these mutations appear as modifications
            # made to composite symbols, we check whether the composite's parent is
            # actually live into the scope.
            # Example:
            #   while cond:
            #     x = Foo()
            #     x.foo = 2 * x.foo  # x.foo is live into the scope, but x is not.
            #
            # Note that some parents might not be symbols - for example, in x['foo'],
            # 'foo' is a parent, but it's a literal, not a symbol. We don't check the
            # liveness of literals.
            support_set_symbols = tuple(sss for sss in s.support_set if sss.is_symbol())
            if not all(sss in live_in for sss in support_set_symbols):
                continue
            composite_scope_vars.append(s)
        return frozenset(composite_scope_vars)

    def _get_block_vars(self, node, modified):
        """Determines the variables affected inside a control flow statement."""
        defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
        live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN)
        live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
        fn_scope = self.state[_Function].scope

        basic_scope_vars = self._get_block_basic_vars(modified, live_in, live_out)
        composite_scope_vars = self._get_block_composite_vars(modified, live_in)
        scope_vars = tuple(basic_scope_vars | composite_scope_vars)

        # Variables that are modified inside the scope, but not defined
        # before entering it. Only simple variables must be defined. The
        # composite ones will be implicitly checked at runtime.
        possibly_undefined = modified - defined_in - fn_scope.globals - fn_scope.nonlocals
        undefined = tuple(v for v in possibly_undefined if not v.is_composite())

        # Variables that are modified inside the scope, and depend on values outside
        # it.
        input_only = basic_scope_vars & live_in - live_out

        # Place the outputs first, then sort lexicographically.
        scope_vars = sorted(scope_vars, key=lambda v: (v in input_only, v))
        nouts = len(scope_vars) - len(input_only)

        return scope_vars, undefined, nouts

    def visit_If(self, node):
        node = self.generic_visit(node)
        body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
        orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)

        cond_vars, undefined, nouts = self._get_block_vars(
            node, body_scope.bound | orelse_scope.bound
        )

        undefined_assigns = self._create_undefined_assigns(undefined)

        nonlocal_declarations = self._create_nonlocal_declarations(cond_vars)

        reserved = body_scope.referenced | orelse_scope.referenced
        state_getter_name = self.ctx.namer.new_symbol("get_state", reserved)
        state_setter_name = self.ctx.namer.new_symbol("set_state", reserved)
        state_functions = self._create_state_functions(
            cond_vars, nonlocal_declarations, state_getter_name, state_setter_name
        )

        orelse_body = node.orelse
        if not orelse_body:
            orelse_body = [gast.Pass()]

        template = """
      state_functions
      def body_name():
        nonlocal_declarations
        body
      def orelse_name():
        nonlocal_declarations
        orelse
      undefined_assigns
      ag__.if_stmt(
        test,
        body_name,
        orelse_name,
        state_getter_name,
        state_setter_name,
        (symbol_names,),
        nouts)
    """
        new_nodes = templates.replace(
            template,
            body=node.body,
            body_name=self.ctx.namer.new_symbol("if_body", reserved),
            orelse=orelse_body,
            orelse_name=self.ctx.namer.new_symbol("else_body", reserved),
            nonlocal_declarations=nonlocal_declarations,
            nouts=gast.Constant(nouts, kind=None),
            state_functions=state_functions,
            state_getter_name=state_getter_name,
            state_setter_name=state_setter_name,
            symbol_names=tuple(gast.Constant(str(s), kind=None) for s in cond_vars),
            test=node.test,
            undefined_assigns=undefined_assigns,
        )
        origin_info.copy_origin(node, new_nodes[-1])
        return new_nodes

    def visit_While(self, node):
        node = self.generic_visit(node)
        body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)

        loop_vars, undefined, _ = self._get_block_vars(node, body_scope.bound)

        undefined_assigns = self._create_undefined_assigns(undefined)

        nonlocal_declarations = self._create_nonlocal_declarations(loop_vars)

        reserved = body_scope.referenced
        state_getter_name = self.ctx.namer.new_symbol("get_state", reserved)
        state_setter_name = self.ctx.namer.new_symbol("set_state", reserved)
        state_functions = self._create_state_functions(
            loop_vars, nonlocal_declarations, state_getter_name, state_setter_name
        )

        opts = self._create_loop_options(node)

        template = """
      state_functions
      def body_name():
        nonlocal_declarations
        body
      def test_name():
        return test
      undefined_assigns
      ag__.while_stmt(
          test_name,
          body_name,
          state_getter_name,
          state_setter_name,
          (symbol_names,),
          opts)
    """
        new_nodes = templates.replace(
            template,
            body=node.body,
            body_name=self.ctx.namer.new_symbol("loop_body", reserved),
            nonlocal_declarations=nonlocal_declarations,
            opts=opts,
            state_functions=state_functions,
            state_getter_name=state_getter_name,
            state_setter_name=state_setter_name,
            symbol_names=tuple(gast.Constant(str(s), kind=None) for s in loop_vars),
            test=node.test,
            test_name=self.ctx.namer.new_symbol("loop_test", reserved),
            undefined_assigns=undefined_assigns,
        )
        origin_info.copy_origin(node, new_nodes[-1])
        return new_nodes

    def visit_For(self, node):
        node = self.generic_visit(node)
        body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
        iter_scope = anno.getanno(node, annos.NodeAnno.ITERATE_SCOPE)

        loop_vars, undefined, _ = self._get_block_vars(node, body_scope.bound | iter_scope.bound)

        undefined_assigns = self._create_undefined_assigns(undefined)

        nonlocal_declarations = self._create_nonlocal_declarations(loop_vars)

        reserved = body_scope.referenced | iter_scope.referenced
        state_getter_name = self.ctx.namer.new_symbol("get_state", reserved)
        state_setter_name = self.ctx.namer.new_symbol("set_state", reserved)
        state_functions = self._create_state_functions(
            loop_vars, nonlocal_declarations, state_getter_name, state_setter_name
        )

        opts = self._create_loop_options(node)
        opts.keys.append(gast.Constant("iterate_names", kind=None))
        opts.values.append(
            gast.Constant(parser.unparse(node.target, include_encoding_marker=False), kind=None)
        )

        if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
            extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST)
            extra_test_name = self.ctx.namer.new_symbol("extra_test", reserved)
            template = """
        def extra_test_name():
          nonlocal_declarations
          return extra_test_expr
      """
            extra_test_function = templates.replace(
                template,
                extra_test_expr=extra_test,
                extra_test_name=extra_test_name,
                loop_vars=loop_vars,
                nonlocal_declarations=nonlocal_declarations,
            )
        else:
            extra_test_name = parser.parse_expression("None")
            extra_test_function = []

        # iterate_arg_name holds a single arg with the iterates, which may be a
        # tuple.
        iterate_arg_name = self.ctx.namer.new_symbol("itr", reserved)
        template = """
      iterates = iterate_arg_name
    """
        iterate_expansion = templates.replace(
            template, iterate_arg_name=iterate_arg_name, iterates=node.target
        )
        origin_info.copy_origin(node, iterate_expansion)

        template = """
      state_functions
      def body_name(iterate_arg_name):
        nonlocal_declarations
        iterate_expansion
        body
      extra_test_function
      undefined_assigns
      ag__.for_stmt(
          iterated,
          extra_test_name,
          body_name,
          state_getter_name,
          state_setter_name,
          (symbol_names,),
          opts)
    """
        new_nodes = templates.replace(
            template,
            body=node.body,
            body_name=self.ctx.namer.new_symbol("loop_body", reserved),
            extra_test_function=extra_test_function,
            extra_test_name=extra_test_name,
            iterate_arg_name=iterate_arg_name,
            iterate_expansion=iterate_expansion,
            iterated=node.iter,
            nonlocal_declarations=nonlocal_declarations,
            opts=opts,
            symbol_names=tuple(gast.Constant(str(s), kind=None) for s in loop_vars),
            state_functions=state_functions,
            state_getter_name=state_getter_name,
            state_setter_name=state_setter_name,
            undefined_assigns=undefined_assigns,
        )
        origin_info.copy_origin(node, new_nodes[-1])
        return new_nodes


class AnnotatedDef(reaching_definitions.Definition):
    def __init__(self):
        super(AnnotatedDef, self).__init__()
        self.directives = {}


def transform(node, ctx):
    graphs = cfg.build(node)
    node = qual_names.resolve(node)
    node = activity.resolve(node, ctx, None)
    node = reaching_definitions.resolve(node, ctx, graphs)
    node = reaching_fndefs.resolve(node, ctx, graphs)
    node = liveness.resolve(node, ctx, graphs)

    node = ControlFlowTransformer(ctx).visit(node)
    return node
